{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RAG Using Azure AI Agent Service & Semantic Kernel\n",
    "\n",
    "This notebook demonstrates how to create and manage an Azure AI agent for retrieval-augmented generation (RAG) using the `Azure AI Agent Service` and `Semantic Kernel`. The agent retrieves context from an uploaded document and answers user queries based only on that context."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initializing the Environment"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SQLite Version Fix\n",
    "\n",
    "If you encounter the error:\n",
    "```\n",
    "RuntimeError: Your system has an unsupported version of sqlite3. Chroma requires sqlite3 >= 3.35.0\n",
    "```\n",
    "\n",
    "Uncomment this code block at the start of your notebook:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install pysqlite3-binary\n",
    "# __import__('pysqlite3')\n",
    "# import sys\n",
    "# sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Importing Packages\n",
    "The following code imports the necessary packages:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Azure imports for project client and credentials\n",
    "from azure.ai.projects import AIProjectClient\n",
    "from azure.ai.projects.models import FileSearchTool, FilePurpose, ToolSet\n",
    "from azure.identity import InteractiveBrowserCredential\n",
    "\n",
    "# Semantic Kernel imports\n",
    "from semantic_kernel.agents.azure_ai import AzureAIAgent\n",
    "from semantic_kernel.contents.utils.author_role import AuthorRole\n",
    "from semantic_kernel.functions.kernel_function_decorator import kernel_function"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Project Client Initialization\n",
    "\n",
    "This cell initializes the `AIProjectClient` using an interactive browser credential and the connection string stored in the `PROJECT_CONNECTION_STRING` environment variable. It then uploads a file to the project and creates a vector store from it for file search retrieval.\n",
    "\n",
    "Note:\n",
    "- You can find the connection string in the project settings.\n",
    "- The `document.md` file contains the travel information that the AI agent searches to ground its answers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Uploaded file, file ID: assistant-VanQY1QAZQgSU2JDdM6nLD\n",
      "Created vector store, vector store ID: vs_LjASqHsFNuEAUPAKxRcVb6SH\n"
     ]
    }
   ],
   "source": [
    "project_client = AIProjectClient.from_connection_string(\n",
    "    credential=InteractiveBrowserCredential(),\n",
    "    conn_str=os.environ[\"PROJECT_CONNECTION_STRING\"],\n",
    ")\n",
    "\n",
    "uploaded_file = project_client.agents.upload_file_and_poll(\n",
    "    file_path=\"document.md\",\n",
    "    purpose=FilePurpose.AGENTS,\n",
    ")\n",
    "print(f\"Uploaded file, file ID: {uploaded_file.id}\")\n",
    "\n",
    "vector_store = project_client.agents.create_vector_store_and_poll(\n",
    "    file_ids=[uploaded_file.id],\n",
    "    name=\"vector_store\"\n",
    ")\n",
    "print(f\"Created vector store, vector store ID: {vector_store.id}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### File Search Tool and Toolset Creation\n",
    "\n",
    "This snippet creates a `FileSearchTool` using the vector store ID and adds it to a `ToolSet`. This toolset will be used by the AI agent for file search retrieval to provide relevant context."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "file_search_tool = FileSearchTool(vector_store_ids=[vector_store.id])\n",
    "toolset = ToolSet()\n",
    "toolset.add(file_search_tool)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prompt Plugin Definition\n",
    "This class defines a `PromptPlugin` with a kernel function that builds an augmented prompt from the user query and the retrieved context. The prompt instructs the AI agent to answer only from the provided context and to report when the context is insufficient."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PromptPlugin:\n",
    "    @kernel_function(\n",
    "        name=\"build_augmented_prompt\",\n",
    "        description=\"Build an augmented prompt using file search retrieval context.\"\n",
    "    )\n",
    "    def build_augmented_prompt(self, query: str, retrieval_context: str) -> str:\n",
    "        return (\n",
    "            f\"Retrieved Context:\\n{retrieval_context}\\n\\n\"\n",
    "            f\"User Query: {query}\\n\\n\"\n",
    "            \"IMPORTANT: Answer ONLY based on the above context. If the context does not provide enough information, respond with 'Insufficient context provided.'\"\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Query Processing Function\n",
    "\n",
    "This asynchronous function processes a user query: it adds the query to the thread as a user message, runs the agent on the thread, and prints the text of the most recent message in the thread to the console."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "async def process_query(project_client, agent, thread_id: str, user_query: str):\n",
    "    print(f\"\\nProcessing query: {user_query}\")\n",
    "\n",
    "    # Add the user's question to the conversation thread.\n",
    "    project_client.agents.create_message(\n",
    "        thread_id=thread_id,\n",
    "        role=AuthorRole.USER,\n",
    "        content=user_query\n",
    "    )\n",
    "\n",
    "    # Run the agent on the thread and wait for the run to finish.\n",
    "    run = project_client.agents.create_and_process_run(\n",
    "        thread_id=thread_id,\n",
    "        assistant_id=agent.id,\n",
    "    )\n",
    "    if run.status == \"failed\":\n",
    "        print(f\"Run failed: {run.last_error}\")\n",
    "\n",
    "    # Fetch the thread's messages; the first entry is treated as the latest reply.\n",
    "    messages = project_client.agents.list_messages(\n",
    "        thread_id=thread_id,\n",
    "    )\n",
    "\n",
    "    print(\"Responses:\")\n",
    "    main_data = messages.get('data', [])\n",
    "    if main_data:\n",
    "        first_msg = main_data[0]\n",
    "        # Print only the text parts of the message content.\n",
    "        for content in first_msg.get(\"content\", []):\n",
    "            if content.get(\"type\") == \"text\":\n",
    "                print(\"Text:\", content[\"text\"].get(\"value\"))\n",
    "\n",
    "    print(\"\\n\" + \"=\"*60 + \"\\n\")"
   ]
  },
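  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Main Function\n",
    "\n",
    "The main function creates the agent with the file search toolset, registers the `PromptPlugin`, opens a conversation thread, and runs two sample queries: one that the uploaded document can answer and one that it cannot. It then deletes the agent, thread, file, and vector store."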
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Created thread, thread ID: thread_AStMkmZVWuUpz2ELu5lvVKQm\n",
      "\n",
      "Processing query: Can you explain Contoso's travel insurance coverage?\n",
      "Responses:\n",
      "Text: Contoso's travel insurance coverage includes protection for medical emergencies, trip cancellations, and lost baggage【4:0†source】.\n",
      "\n",
      "============================================================\n",
      "\n",
      "\n",
      "Processing query: What is Neural Network?\n",
      "Responses:\n",
      "Text: The context does not contain any information about neural networks. Therefore, I cannot provide an explanation based on the uploaded documents.\n",
      "\n",
      "============================================================\n",
      "\n",
      "\n",
      "Cleaned up agent, thread, file, and vector store.\n"
     ]
    }
   ],
   "source": [
    "async def main():\n",
    "    async with AzureAIAgent.create_client(\n",
    "        credential=InteractiveBrowserCredential(),\n",
    "        conn_str=os.environ[\"PROJECT_CONNECTION_STRING\"],\n",
    "    ) as client:\n",
    "        # Define agent name and instructions tailored for RAG.\n",
    "        AGENT_NAME = \"RAGAgent\"\n",
    "        AGENT_INSTRUCTIONS = (\n",
    "            \"You are an AI assistant that provides accurate information based on the provided context. \"\n",
    "            \"Only use information from the retrieved context to answer questions. \"\n",
    "            \"If the context doesn't contain relevant information, clearly state that. \"\n",
    "            \"Provide concise and accurate responses.\"\n",
    "        )\n",
    "        \n",
    "        # Create agent definition.\n",
    "        agent_definition = await client.agents.create_agent(\n",
    "            model=\"gpt-4o-mini\",\n",
    "            name=AGENT_NAME,\n",
    "            instructions=AGENT_INSTRUCTIONS,\n",
    "            toolset=toolset,\n",
    "        )\n",
    "        \n",
    "        # Create the Azure AI Agent using the client and definition.\n",
    "        agent = AzureAIAgent(\n",
    "            client=client,\n",
    "            definition=agent_definition,\n",
    "        )\n",
    "        \n",
    "        # Add the PromptPlugin for retrieval-augmented generation.\n",
    "        agent.kernel.add_plugin(PromptPlugin(), plugin_name=\"promptPlugin\")\n",
    "        \n",
    "        # Create a conversation thread.\n",
    "        thread = await client.agents.create_thread()\n",
    "        print(f\"Created thread, thread ID: {thread.id}\")\n",
    "        \n",
    "        # Example user queries.\n",
    "        user_inputs = [\n",
    "            \"Can you explain Contoso's travel insurance coverage?\",  # Relevant context.\n",
    "            \"What is Neural Network?\"  # No relevant context.\n",
    "        ]\n",
    "        \n",
    "        try:\n",
    "            for user_query in user_inputs:\n",
    "                await process_query(project_client, agent, thread.id, user_query)\n",
    "        finally:\n",
    "            # Clean up resources.\n",
    "            await client.agents.delete_thread(thread.id)\n",
    "            await client.agents.delete_agent(agent.id)\n",
    "            project_client.agents.delete_file(uploaded_file.id)\n",
    "            project_client.agents.delete_vector_store(vector_store.id)\n",
    "            print(\"\\nCleaned up agent, thread, file, and vector store.\")\n",
    "\n",
    "await main()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
